package org.springframework.security.oauth.consumer.filter;

import static org.junit.Assert.assertEquals;
import static org.mockito.Matchers.anyObject;
import static org.mockito.Matchers.anyString;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import java.io.IOException;

import javax.servlet.FilterChain;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.junit.After;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.runners.MockitoJUnitRunner;
import org.springframework.security.oauth.common.OAuthProviderParameter;
import org.springframework.security.oauth.consumer.AccessTokenRequiredException;
import org.springframework.security.oauth.consumer.BaseProtectedResourceDetails;
import org.springframework.security.oauth.consumer.OAuthConsumerSupport;
import org.springframework.security.oauth.consumer.OAuthConsumerToken;
import org.springframework.security.oauth.consumer.OAuthSecurityContextHolder;
import org.springframework.security.oauth.consumer.ProtectedResourceDetails;
import org.springframework.security.oauth.consumer.rememberme.NoOpOAuthRememberMeServices;
import org.springframework.security.oauth.consumer.rememberme.OAuthRememberMeServices;
import org.springframework.security.oauth.consumer.token.OAuthConsumerTokenServices;
import org.springframework.security.web.RedirectStrategy;

/**
 * Unit tests for {@link OAuthConsumerContextFilter}. The servlet request/response, the filter chain and the OAuth
 * consumer collaborators are all Mockito mocks injected by the runner.
 *
 * @author Ryan Heaton
 */
@RunWith(MockitoJUnitRunner.class)
public class OAuthConsumerContextFilterTests {
	/** Id of the protected resource used by {@link #testDoFilter()}. */
	private static final String RESOURCE_ID = "dep1";

	/** Callback URL after the mocked response has "encoded" it. */
	private static final String ENCODED_CALLBACK = "urn:callback?query";

	@Mock
	private ProtectedResourceDetails details;
	@Mock
	private HttpServletRequest request;
	@Mock
	private HttpServletResponse response;
	@Mock
	private FilterChain filterChain;
	@Mock
	private OAuthConsumerTokenServices tokenServices;
	@Mock
	private OAuthConsumerSupport support;

	/**
	 * Resets the context held by {@link OAuthSecurityContextHolder} to {@code null} once each test has run.
	 */
	@After
	public void tearDown() throws Exception {
		OAuthSecurityContextHolder.setContext(null);
	}

	/**
	 * Checks the redirect URL built for user authorization: when the resource details report
	 * {@code isUse10a() == false} the URL-encoded callback is expected as an {@code oauth_callback} parameter,
	 * and when they report {@code true} only the {@code oauth_token} parameter is expected.
	 */
	@Test
	public void testGetUserAuthorizationRedirectURL() throws Exception {
		final String authorizationUrl = "https://user-auth/context?with=some&queryParams";
		final String callbackUrl = "urn://callback?with=some&query=params";

		OAuthConsumerToken requestToken = new OAuthConsumerToken();
		requestToken.setResourceId("resourceId");
		requestToken.setValue("mytoken");

		OAuthConsumerContextFilter contextFilter = new OAuthConsumerContextFilter();

		// isUse10a() == false: token and encoded callback are both appended.
		when(details.isUse10a()).thenReturn(false);
		when(details.getUserAuthorizationURL()).thenReturn(authorizationUrl);
		String redirectWithCallback = contextFilter.getUserAuthorizationRedirectURL(details, requestToken, callbackUrl);
		assertEquals(
				"https://user-auth/context?with=some&queryParams&oauth_token=mytoken&oauth_callback=urn%3A%2F%2Fcallback%3Fwith%3Dsome%26query%3Dparams",
				redirectWithCallback);

		// isUse10a() == true: only the token is appended.
		when(details.isUse10a()).thenReturn(true);
		when(details.getUserAuthorizationURL()).thenReturn(authorizationUrl);
		String redirectWithoutCallback = contextFilter.getUserAuthorizationRedirectURL(details, requestToken, callbackUrl);
		assertEquals("https://user-auth/context?with=some&queryParams&oauth_token=mytoken", redirectWithoutCallback);
	}

	/**
	 * Drives the filter through two passes in which the chain throws {@link AccessTokenRequiredException}:
	 * <ol>
	 * <li>no stored token and no verifier: a request token is expected to be stored and a redirect sent;</li>
	 * <li>a stored request token plus a verifier: the request token is expected to be removed, the access token
	 * stored, and the chain invoked a second time.</li>
	 * </ol>
	 */
	@Test
	public void testDoFilter() throws Exception {
		BaseProtectedResourceDetails resource = new BaseProtectedResourceDetails();
		resource.setId(RESOURCE_ID);

		OAuthRememberMeServices noOpRememberMe = new NoOpOAuthRememberMeServices();
		OAuthConsumerContextFilter contextFilter = newFilterUnderTest();
		contextFilter.setRememberMeServices(noOpRememberMe);
		contextFilter.setConsumerSupport(support);
		contextFilter.setTokenServices(tokenServices);

		OAuthConsumerToken requestToken = new OAuthConsumerToken();
		requestToken.setAccessToken(false);
		requestToken.setResourceId(resource.getId());

		// ---- pass 1: nothing stored, no verifier on the request ----
		when(request.getParameter("oauth_verifier")).thenReturn(null);
		when(tokenServices.getToken(RESOURCE_ID)).thenReturn(null);
		when(response.encodeRedirectURL("urn:callback")).thenReturn(ENCODED_CALLBACK);
		when(support.getUnauthorizedRequestToken(RESOURCE_ID, ENCODED_CALLBACK)).thenReturn(requestToken);
		chainRequiresAccessTo(resource);

		contextFilter.doFilter(request, response, filterChain);

		verify(request, times(2)).setAttribute(anyString(), anyObject());
		verify(response, times(1)).sendRedirect("urn:callback?query&dep1");
		verify(tokenServices, times(1)).storeToken(RESOURCE_ID, requestToken);
		verify(filterChain, times(1)).doFilter(request, response);

		// Wipe the servlet mocks so pass 2 stubs and counts from scratch; tokenServices and support are not reset.
		reset(request, response, filterChain);

		// ---- pass 2: request token is stored and the verifier parameter is present ----
		OAuthConsumerToken accessToken = new OAuthConsumerToken();
		String verifierParameter = OAuthProviderParameter.oauth_verifier.toString();
		when(response.isCommitted()).thenReturn(false);
		when(request.getParameter(verifierParameter)).thenReturn("verifier");
		when(tokenServices.getToken(RESOURCE_ID)).thenReturn(requestToken);
		when(support.getAccessToken(requestToken, "verifier")).thenReturn(accessToken);
		chainRequiresAccessTo(resource);

		contextFilter.doFilter(request, response, filterChain);

		verify(request, times(2)).setAttribute(anyString(), anyObject());
		verify(tokenServices, times(1)).storeToken(RESOURCE_ID, accessToken);
		verify(tokenServices, times(1)).removeToken(RESOURCE_ID);
		verify(filterChain, times(2)).doFilter(request, response);
	}

	/**
	 * Stubs the mocked filter chain so that processing the mocked request/response throws an
	 * {@link AccessTokenRequiredException} for the given resource.
	 */
	private void chainRequiresAccessTo(BaseProtectedResourceDetails resource) throws Exception {
		doThrow(new AccessTokenRequiredException(resource)).when(filterChain).doFilter(request, response);
	}

	/**
	 * Builds a filter whose callback URL is fixed to {@code "urn:callback"}, whose authorization redirect URL is
	 * {@code callbackURL + "&" + resourceId}, and whose redirects go straight to the response.
	 */
	private static OAuthConsumerContextFilter newFilterUnderTest() {
		OAuthConsumerContextFilter stubbedFilter = new OAuthConsumerContextFilter() {
			@Override
			protected String getCallbackURL(HttpServletRequest servletRequest) {
				return "urn:callback";
			}

			@Override
			protected String getUserAuthorizationRedirectURL(ProtectedResourceDetails resourceDetails,
					OAuthConsumerToken unauthorizedToken, String callback) {
				return callback + "&" + unauthorizedToken.getResourceId();
			}
		};
		stubbedFilter.setRedirectStrategy(new ResponseRedirectStrategy());
		return stubbedFilter;
	}

	/**
	 * Redirect strategy that simply forwards the URL to {@link HttpServletResponse#sendRedirect(String)}.
	 */
	private static final class ResponseRedirectStrategy implements RedirectStrategy {
		public void sendRedirect(HttpServletRequest servletRequest, HttpServletResponse servletResponse, String url)
				throws IOException {
			servletResponse.sendRedirect(url);
		}
	}
}
